# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from __future__ import absolute_import

import functools
import math

import te.platform as tbe_platform
from tbe.common.buildcfg.buildcfg import build_config
from mindspore.ops import DataType
from mindspore.ops import TBERegOp
from te import tvm
from te.utils import para_check

from accspeed.ops.dropout_mask_decoder.constants import \
    DROPOUT_MASK_DECODER_KERNEL_NAME


# Number of fp16 elements one vector instruction processes per repeat.
VEC_OP_FP16_MAX_NUM = tbe_platform.ELEMENTS_VECTOR_OP_FP16  # 128

# MindSpore TBE op registration info for the DropoutMaskDecoder primitive:
# a single required uint8 input ("input_mask", the packed bit mask) and a
# single required uint8 output ("res", the decoded 0/1 byte mask).
cus_mask_decoder_op_info = TBERegOp("DropoutMaskDecoderPrimitive") \
    .fusion_type("OPAQUE") \
    .partial_flag(True) \
    .async_flag(False) \
    .binfile_name(f"{DROPOUT_MASK_DECODER_KERNEL_NAME}.so") \
    .compute_cost(10) \
    .kernel_name(DROPOUT_MASK_DECODER_KERNEL_NAME) \
    .input(0, "input_mask", False, "required", "all") \
    .output(0, "res", False, "required", "all") \
    .dtype_format(DataType.U8_Default,
                  DataType.U8_Default,
                  )\
    .get_op_info()


def _alloc(ir_builder, dtype, shape, name, scope):
    """
    Allocate storage via the IR builder and declare a matching symbolic buffer.

    Parameters
    ----------
    ir_builder : tvm.ir_builder
        IR node builder used to perform the allocation.
    dtype : string
        Element data type of the buffer.
    shape : tuple of int
        Buffer shape.
    name : string
        Buffer name.
    scope : string
        Buffer memory scope (e.g. unified buffer).

    Returns
    -------
    buffer : tvm.schedule.Buffer
        Symbolic buffer bound to the freshly allocated storage.

    """
    storage = ir_builder.allocate(dtype, shape, name=name, scope=scope)
    return tvm.decl_buffer(shape,
                           storage.dtype,
                           name=name,
                           scope=scope,
                           data=storage)


def _get_ub_max_elements(dtype):
    """
    Compute how many elements of ``dtype`` fit in the unified buffer (UB).

    The UB must hold tensor_data + tensor_zero + tensor_mask at once, i.e.
    1 + 1 + 0.125 = 2.125 element-slots per processed element, with one
    128-element vector reserved as scratch. The result is floored to a
    multiple of both 128 and 256 (32 bytes * 8 bits).

    Parameters
    ----------
    dtype : Input data type
    :return : how many elems can put in ub
    """
    ub_bytes = tbe_platform.cce_conf.get_soc_spec(tbe_platform.cce_conf.UB_SIZE)
    # 8 bits per byte
    elem_bytes = tbe_platform.cce_intrin.get_bit_len(dtype) // 8
    # 2.125 means tensor_data + tensor_zero + tensor_mask = 1 + 1 + 0.125
    usable = (ub_bytes // elem_bytes - VEC_OP_FP16_MAX_NUM) // 2.125
    aligned = int(usable // VEC_OP_FP16_MAX_NUM) * VEC_OP_FP16_MAX_NUM
    return int(aligned // (32 * 8)) * 32 * 8


def decode_mask(ir_builder, place_holders, loop_paras, block_offset, core_out_mask_num):
    """Emit IR that expands a packed bit mask into one uint8 (0/1) per bit.

    place_holders = [input_mask_holder, output_mask_holder]

    Parameters
    ----------
    ir_builder : tvm.ir_builder
        Builder the generated statements are emitted into.
    place_holders : list
        [input_mask_holder (uint8, packed bits), output_mask_holder (uint8,
        one byte per decoded bit)].
    loop_paras : list
        [core_loop_cnt, ub_op_fp16_loop_cnt, core_loop_remain,
        mask_num_tail_by_128]; entries 1 and 3 are overwritten in the tail
        path below.
    block_offset : int or tvm expr
        This core's element offset into the decoded output.
    core_out_mask_num : int
        Number of decoded output elements this core produces.

    The expansion works per 128-element group: `set_cmpmask` loads 16 mask
    bytes (128 bits), `vsel` picks between an all-ones and an all-zeros fp16
    vector, and `vconv_f162u8` narrows the fp16 result to uint8.
    """
    # offsets[0]/offsets[1]: running element offsets into output/input GM
    # (input is packed, hence the // 8). offsets[4]/offsets[5]: tail-path
    # offsets into the UB output buffer / UB mask buffer. Slots 2 and 3 are
    # unused.
    offsets = [0] * 6

    offsets[0] = offsets[0] + block_offset
    offsets[1] = offsets[1] + block_offset // 8

    SCOPE_UB = tbe_platform.scope_ubuf

    # Constant vectors used as the two vsel sources (128 fp16 each).
    zero_ub = _alloc(ir_builder, 'float16', (VEC_OP_FP16_MAX_NUM,), "zero_ub", scope=SCOPE_UB)
    one_ub = _alloc(ir_builder, 'float16', (VEC_OP_FP16_MAX_NUM,), "one_ub", scope=SCOPE_UB)

    fp16_bytes = tbe_platform.get_type_bits('float16') // 8
    uint8_bits = tbe_platform.get_type_bits('uint8')
    # fp16 elements one vector instruction covers, and the packed-mask bytes
    # (bits // 8) needed to drive one such instruction.
    vector_op_fp16_max_num_once = tbe_platform.VECTOR_INST_BLOCK_WIDTH // fp16_bytes
    ub_fp16_max_num = _get_ub_max_elements('float16')
    vector_op_fp16_max_num_need_mask_bytes = vector_op_fp16_max_num_once // uint8_bits

    core_loop_cnt, ub_op_fp16_loop_cnt, core_loop_remain, mask_num_tail_by_128 = loop_paras

    fp16_const_1 = tvm.const(1.0, dtype="float16")
    fp16_const_0 = tvm.const(0.0, dtype='float16')

    # Fill the two constant vectors with 0.0 and 1.0.
    ir_builder.emit(
        tvm.call_extern('float16', "vector_dup",
                        zero_ub.access_ptr("rw"),  # noqa
                        fp16_const_0, 1, 1, 1, 8, 8))
    ir_builder.emit(
        tvm.call_extern("float16", "vector_dup",
                        one_ub.access_ptr("rw"),  # noqa
                        fp16_const_1, 1, 1, 1, 8, 8))

    i_mask_holder = place_holders[0]
    o_mask_holder = place_holders[1]

    # --- Main path: full UB-sized chunks of ub_fp16_max_num elements. ---
    if core_loop_cnt > 0:
        o_mask_ub = _alloc(ir_builder, 'float16', (ub_fp16_max_num,), "o_mask_ub", scope=SCOPE_UB)
        o_mask_ub_u8 = _alloc(ir_builder, 'uint8', (ub_fp16_max_num,), "o_mask_ub_u8",
                              scope=SCOPE_UB)
        i_mask_ub = _alloc(ir_builder, 'uint8', (ub_fp16_max_num // 8,), "i_mask_ub",
                           scope=SCOPE_UB)

        # Burst lengths in 32-byte units: fp16 output = num // 16 elements,
        # packed uint8 input = (num // 8 bytes) // 32.
        o_mask_cp_burst_len = ub_fp16_max_num // 16
        i_mask_cp_burst_len = ub_fp16_max_num // 8 // 32

        with ir_builder.for_range(0,  core_loop_cnt, name='index0') as index0:
            o_mask_wr_offset = block_offset + ub_fp16_max_num * index0
            i_mask_rd_offset = o_mask_wr_offset // 8

            ir_builder.emit(
                tvm.call_extern(
                    'uint8', "copy_gm_to_ubuf",  # copy mask to UB
                    i_mask_ub.access_ptr("w"),  # noqa
                    i_mask_holder.access_ptr("r", offset=i_mask_rd_offset), 0, 1, i_mask_cp_burst_len, 0, 0))

            # Decode one 128-element group per iteration; each group consumes
            # 16 packed bytes (128 bits) of the mask.
            with ir_builder.for_range(0, ub_op_fp16_loop_cnt, name='index1') as index1:
                ir_builder.emit(
                    tvm.call_extern(
                        'uint8', 'set_cmpmask',  # set mask according to data_mask_ub
                        i_mask_ub.access_ptr('r', offset=16 * index1))  # noqa
                )
                # vsel from alloc_res[4]/alloc_res[0] write to  alloc_res[4]
                offset = VEC_OP_FP16_MAX_NUM * index1
                ir_builder.emit(
                    tvm.call_extern('float16', 'vsel',
                                    o_mask_ub.access_ptr('w', offset=offset),  # noqa
                                    one_ub.access_ptr('r', offset=0),  # noqa
                                    zero_ub.access_ptr('r', offset=0), 1, 1, 1, 1, 0, 0, 0))# noqa
                ir_builder.emit(
                    tvm.call_extern("uint8", "vconv_f162u8",
                                    o_mask_ub_u8.access_ptr("w", offset=offset),
                                    o_mask_ub.access_ptr("r", offset=offset),
                                    1,  # repeat times
                                    1,  # dst block stride
                                    1,  # src block stride
                                    0,  # dst repeat stride
                                    0))  # src repeat stride
            # NOTE(review): declared 'float16' while the data copied is the
            # uint8 buffer; burst_len // 2 compensates (uint8 bytes fill half
            # as many 32-byte bursts as fp16 elements would) — confirm this
            # matches the intrinsic's expectations.
            ir_builder.emit(
                tvm.call_extern(
                    'float16', "copy_ubuf_to_gm",  # copy to out
                    o_mask_holder.access_ptr('w', offset=o_mask_wr_offset),
                    o_mask_ub_u8.access_ptr("r"), 0, 1, o_mask_cp_burst_len//2, 0, 0))  # noqa

        # Advance GM offsets past the fully processed chunks.
        offsets[0] = offsets[0] + ub_fp16_max_num * core_loop_cnt
        offsets[1] = offsets[1] + ub_fp16_max_num * core_loop_cnt // 8

    # --- Tail path: remainder smaller than one UB chunk. ---
    if core_loop_remain:
        o_mask_tail_num = core_out_mask_num - ub_fp16_max_num * core_loop_cnt
        i_mask_tail_num = core_out_mask_num // 8 - ub_fp16_max_num // 8 * core_loop_cnt

        o_mask_ub = _alloc(ir_builder, 'float16', (o_mask_tail_num,), "o_mask_ub", scope=SCOPE_UB)
        o_mask_ub_u8 = _alloc(ir_builder, 'uint8', (o_mask_tail_num,), "o_mask_ub_u8",
                              scope=SCOPE_UB)
        i_mask_ub = _alloc(ir_builder, 'uint8', (i_mask_tail_num,), "i_mask_ub", scope=SCOPE_UB)

        # ceil-divide so partial bursts are still transferred whole.
        o_mask_tail_burst_len = int(math.ceil(o_mask_tail_num * 1.0 / 16))
        i_mask_tail_rd_burst_len = int(math.ceil(i_mask_tail_num * 1.0 / 32))

        ir_builder.emit(
            tvm.call_extern(
                'uint8', "copy_gm_to_ubuf",  # gm_mask -> UB.
                i_mask_ub.access_ptr("w"),  # noqa
                i_mask_holder.access_ptr("r", offset=offsets[1]), 0, 1,
                i_mask_tail_rd_burst_len, 0, 0))

        # Re-split the tail: [1] full 128-element groups, [3] leftover bits.
        loop_paras[1] = o_mask_tail_num // VEC_OP_FP16_MAX_NUM
        loop_paras[3] = o_mask_tail_num % VEC_OP_FP16_MAX_NUM

        loops = loop_paras[1]
        with ir_builder.for_range(0, loops, name='index2') as index2:
            ir_builder.emit(
                tvm.call_extern(
                    'uint8', 'set_cmpmask',
                    i_mask_ub.access_ptr('r', offset=16 * index2)))  # noqa

            offset = VEC_OP_FP16_MAX_NUM * index2
            ir_builder.emit(
                tvm.call_extern('float16', 'vsel',
                                o_mask_ub.access_ptr('w', offset=offset),  # noqa
                                one_ub.access_ptr('r', offset=0),  # noqa
                                zero_ub.access_ptr('r', offset=0), 1, 1, 1, 1, 0, 0, 0)) # noqa
            ir_builder.emit(
                tvm.call_extern("uint8", "vconv_f162u8",
                                o_mask_ub_u8.access_ptr("w", offset=offset),
                                o_mask_ub.access_ptr("r", offset=offset),
                                1,  # repeat times
                                1,  # dst block stride
                                1,  # src block stride
                                0,  # dst repeat stride
                                0))  # src repeat stride

        # UB offsets where the sub-128 remainder continues.
        offsets[4] = vector_op_fp16_max_num_once * loops
        offsets[5] = vector_op_fp16_max_num_need_mask_bytes * loops

        if loop_paras[3]:
            # Shrink the active vector mask to the remaining element count.
            tbe_platform.reset_mask_insn(ir_builder, 'float16', bits=loop_paras[3], mask_func=None)

            # NOTE(review): the remainder decode is only emitted when the tail
            # contained no full 128-element group (loop_paras[1] == 0); a
            # remainder following full groups appears to be skipped — confirm
            # this is intended.
            if loop_paras[1] == 0:
                ir_builder.emit(
                    tvm.call_extern(
                        'uint8', 'set_cmpmask',
                        i_mask_ub.access_ptr('r', offset=offsets[5])))  # noqa
                ir_builder.emit(
                    tvm.call_extern('float16', 'vsel',
                                    o_mask_ub.access_ptr('w', offset=offsets[4]),  # noqa
                                    one_ub.access_ptr('r', offset=0),  # noqa
                                    zero_ub.access_ptr('r', offset=0), 1, 1, 1, 1, 0, 0, 0)) # noqa
                ir_builder.emit(
                    tvm.call_extern("uint8", "vconv_f162u8",
                                    o_mask_ub_u8.access_ptr("w", offset=offsets[4]),
                                    o_mask_ub.access_ptr("r", offset=offsets[4]),
                                    1,  # repeat times
                                    1,  # dst block stride
                                    1,  # src block stride
                                    0,  # dst repeat stride
                                    0))  # src repeat stride

        # Flush the decoded tail back to GM (same fp16/uint8 burst-length
        # halving as the main path above).
        ir_builder.emit(
            tvm.call_extern(
                'float16', "copy_ubuf_to_gm",
                o_mask_holder.access_ptr('w', offset=offsets[0]),
                o_mask_ub_u8.access_ptr("r"), 0, 1, o_mask_tail_burst_len//2, 0, 0))  # noqa


def _kernel_ir(dst, src):
    """
    dropout_do_mask kernel: build the multi-core IR for mask decoding.

    :param dst: Destination address (list of output tensors; dst[0] used)
    :param src: Original Address (list of input tensors; src[0] used)
    :return: ir builder statement (the lowered IR body)
    """
    ir_builder = tvm.ir_builder.create()
    place_holders = [src[0], dst[0]]  # input & output params

    ai_core_num, mask_128_grp_num_per_core, mask_128_grp_tail_by_core, mask_num_tail_by_128 = \
        do_tiling(src[0])

    ub_fp16_max_num = _get_ub_max_elements('float16')
    ub_op_fp16_loop_cnt = ub_fp16_max_num // VEC_OP_FP16_MAX_NUM

    # One IR instance per AI core, distinguished by blockIdx.x.
    blk_idx = tvm.thread_axis("blockIdx.x")
    ir_builder.scope_attr(blk_idx, "thread_extent", ai_core_num)

    # Cores with index < tail get one extra 128-bit group each.
    with ir_builder.if_scope(blk_idx < mask_128_grp_tail_by_core):
        core_out_mask_num = (mask_128_grp_num_per_core + 1) * VEC_OP_FP16_MAX_NUM * 8
        blk_offset = core_out_mask_num * blk_idx
        # 0:loop_for_ub 1:loop_for_128
        # 2:remain_data_ub(after tilling by ub max process elements) 3:remain_ele
        core_loop_cnt = int(core_out_mask_num) // ub_fp16_max_num # must use int conversion.
        core_loop_remain = int(core_out_mask_num) % ub_fp16_max_num
        loops_paras = [
            core_loop_cnt,
            ub_op_fp16_loop_cnt,
            core_loop_remain,
            mask_num_tail_by_128
        ]
        decode_mask(ir_builder, place_holders, loops_paras, blk_offset, core_out_mask_num)
    with ir_builder.else_scope():
        # one byte i_mask output 8 masks.
        core_out_mask_num = mask_128_grp_num_per_core * VEC_OP_FP16_MAX_NUM * 8
        blk_offset = VEC_OP_FP16_MAX_NUM * 8 * mask_128_grp_tail_by_core + core_out_mask_num * blk_idx
        if mask_num_tail_by_128:
            with ir_builder.if_scope(blk_idx == ai_core_num - 1):
                # NOTE(review): this Python-level += runs unconditionally at
                # trace time (the if_scope only guards emitted IR, and nothing
                # is emitted inside it), so every core in this else-branch is
                # traced with the enlarged count — confirm this is intended.
                core_out_mask_num += mask_num_tail_by_128 * 8
        # 0:loop_for_ub 1:loop_for_128
        # 2:remain_data_ub(after tilling by ub max process elements) 3:remain_ele
        core_loop_cnt = int(core_out_mask_num) // ub_fp16_max_num
        core_loop_remain = int(core_out_mask_num) % ub_fp16_max_num
        loops_paras = [
            core_loop_cnt,
            ub_op_fp16_loop_cnt,
            core_loop_remain,
            0]
        decode_mask(ir_builder, place_holders, loops_paras, blk_offset, core_out_mask_num)

    return ir_builder.get()


def do_tiling(input_mask):
    """Do tiling according to the input mask shape.

    :param input_mask: input dropout mask of uint8 (1-D placeholder; its
        first shape dimension is the packed byte count)
    :return: tuple of
        ai_core_num, mask_128_grp_num_per_core, mask_128_grp_tail_by_core,
        mask_num_tail_by_128
    """
    mask_shape = input_mask.shape[:]
    ai_core_num = tbe_platform.cce_conf.get_soc_spec(tbe_platform.cce_conf.CORE_NUM)

    i_mask_num = mask_shape[0]
    # Split the packed mask bytes into groups of 128.
    mask_128_grp_num = int(i_mask_num // 128)
    mask_num_tail_by_128 = int(i_mask_num % 128)
    mask_128_grp_tail_by_core = 0

    # Fewer groups than cores: one group per core, shrink the core count
    # (at least one core even for an empty mask).
    if mask_128_grp_num <= ai_core_num:
        ai_core_num = mask_128_grp_num if mask_128_grp_num != 0 else 1
        mask_128_grp_num_per_core = 1
        return ai_core_num, mask_128_grp_num_per_core, mask_128_grp_tail_by_core, mask_num_tail_by_128

    # Distribute the 128-byte groups evenly across cores; the remainder
    # groups are handled one-per-core by the leading cores.
    mask_128_grp_num_per_core = mask_128_grp_num // ai_core_num
    mask_128_grp_tail_by_core = mask_128_grp_num % ai_core_num

    return ai_core_num, mask_128_grp_num_per_core, mask_128_grp_tail_by_core, mask_num_tail_by_128


def decode_dropout_mask(input_mask, output, kernel_name=DROPOUT_MASK_DECODER_KERNEL_NAME):
    """
    algorithm: decode every input U8 byte to 8 bytes U8 mask, for example:
    input:231 (0b 1110 0111)--->output: [1, 1, 1, 0, 0, 1, 1, 1]

    Parameters
    ----------
    input_mask : dict,shape and dtype of input_mask
        dtype should be uint8 or uint1
        if dtype is uint8, mask is 1D,
            length=(size(shape_tensor)+VEC_OP_FP16_MAX_NUM
            -1)/VEC_OP_FP16_MAX_NUM*VEC_OP_FP16_MAX_NUM/8
            eg. shape_tensor=[2,5,8] shape_mask=[16] shape_res=[2,5,8]
            shape_tensor=[15,17,19] shape_mask=[608] shape_res=[15,17,19]
    output : dict,shape and dtype of output
    kernel_name : str
        cce kernel name, default value is "dropout_mask_decoder_impl"

    Returns
    -------
    None
    """
    shape_mask = input_mask.get("shape")
    dtype_mask = input_mask.get("dtype")

    # Only packed uint8 masks are supported by this kernel.
    para_check.check_dtype(dtype_mask.lower(), ["uint8"], param_name="input_mask")

    # Flatten the mask to 1-D: total element count = product of shape dims.
    data_mask = tvm.placeholder(
        (functools.reduce(lambda x, y: x*y, shape_mask), ),
        dtype='uint8',
        name="input_mask")

    # NOTE(review): the extern output shape equals the packed input shape,
    # although the decoded result is 8x larger per the algorithm above —
    # confirm the declared shape is what downstream consumers expect.
    res = tvm.extern([shape_mask],
                     [data_mask],
                     lambda ins, outs: _kernel_ir(outs, ins),
                     name="res",
                     dtype='uint8')

    tensor_list = [data_mask, res]
    schedule = tvm.create_schedule(res.op)

    # Lower and build the CCE kernel binary under the TBE build config.
    with build_config():
        tvm.build(schedule, tensor_list, "cce", name=kernel_name)
